{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Native APIs\n",
    "\n",
    "Apart from the OpenAI compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce these following APIs:\n",
    "\n",
    "- `/generate` (text generation model)\n",
    "- `/get_model_info`\n",
    "- `/get_server_info`\n",
    "- `/health`\n",
    "- `/health_generate`\n",
    "- `/flush_cache`\n",
    "- `/update_weights`\n",
    "- `/encode`(embedding model)\n",
    "- `/classify`(reward model)\n",
    "\n",
    "We mainly use `requests` to test these APIs in the following examples. You can also use `curl`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Launch A Server"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sglang.utils import (\n",
    "    execute_shell_command,\n",
    "    wait_for_server,\n",
    "    terminate_process,\n",
    "    print_highlight,\n",
    ")\n",
    "\n",
    "import requests\n",
    "\n",
    "server_process = execute_shell_command(\n",
    "    \"\"\"\n",
    "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010\n",
    "\"\"\"\n",
    ")\n",
    "\n",
    "wait_for_server(\"http://localhost:30010\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate (text generation model)\n",
    "Generate completions. This is similar to the `/v1/completions` in OpenAI API. Detailed parameters can be found in the [sampling parameters](../references/sampling_params.md)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"http://localhost:30010/generate\"\n",
    "data = {\"text\": \"What is the capital of France?\"}\n",
    "\n",
    "response = requests.post(url, json=data)\n",
    "print_highlight(response.json())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get Model Info\n",
    "\n",
    "Get the information of the model.\n",
    "\n",
    "- `model_path`: The path/name of the model.\n",
    "- `is_generation`: Whether the model is used as generation model or embedding model.\n",
    "- `tokenizer_path`: The path/name of the tokenizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"http://localhost:30010/get_model_info\"\n",
    "\n",
    "response = requests.get(url)\n",
    "response_json = response.json()\n",
    "print_highlight(response_json)\n",
    "assert response_json[\"model_path\"] == \"meta-llama/Llama-3.2-1B-Instruct\"\n",
    "assert response_json[\"is_generation\"] is True\n",
    "assert response_json[\"tokenizer_path\"] == \"meta-llama/Llama-3.2-1B-Instruct\"\n",
    "assert response_json.keys() == {\"model_path\", \"is_generation\", \"tokenizer_path\"}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get Server Info\n",
    "Gets the server information including CLI arguments, token limits, and memory pool sizes.\n",
    "- Note: `get_server_info` merges the following deprecated endpoints:\n",
    "  - `get_server_args`\n",
    "  - `get_memory_pool_size` \n",
    "  - `get_max_total_num_tokens`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_server_info\n",
    "\n",
    "url = \"http://localhost:30010/get_server_info\"\n",
    "\n",
    "response = requests.get(url)\n",
    "print_highlight(response.text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Health Check\n",
    "- `/health`: Check the health of the server.\n",
    "- `/health_generate`: Check the health of the server by generating one token."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"http://localhost:30010/health_generate\"\n",
    "\n",
    "response = requests.get(url)\n",
    "print_highlight(response.text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"http://localhost:30010/health\"\n",
    "\n",
    "response = requests.get(url)\n",
    "print_highlight(response.text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Flush Cache\n",
    "\n",
    "Flush the radix cache. It will be automatically triggered when the model weights are updated by the `/update_weights` API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# flush cache\n",
    "\n",
    "url = \"http://localhost:30010/flush_cache\"\n",
    "\n",
    "response = requests.post(url)\n",
    "print_highlight(response.text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Update Weights From Disk\n",
    "\n",
    "Update model weights from disk without restarting the server. Only applicable for models with the same architecture and parameter size.\n",
    "\n",
    "SGLang support `update_weights_from_disk` API for continuous evaluation during training (save checkpoint to disk and update weights from disk).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# successful update with same architecture and size\n",
    "\n",
    "url = \"http://localhost:30010/update_weights_from_disk\"\n",
    "data = {\"model_path\": \"meta-llama/Llama-3.2-1B\"}\n",
    "\n",
    "response = requests.post(url, json=data)\n",
    "print_highlight(response.text)\n",
    "assert response.json()[\"success\"] is True\n",
    "assert response.json()[\"message\"] == \"Succeeded to update model weights.\"\n",
    "assert response.json().keys() == {\"success\", \"message\"}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# failed update with different parameter size or wrong name\n",
    "\n",
    "url = \"http://localhost:30010/update_weights_from_disk\"\n",
    "data = {\"model_path\": \"meta-llama/Llama-3.2-1B-wrong\"}\n",
    "\n",
    "response = requests.post(url, json=data)\n",
    "response_json = response.json()\n",
    "print_highlight(response_json)\n",
    "assert response_json[\"success\"] is False\n",
    "assert response_json[\"message\"] == (\n",
    "    \"Failed to get weights iterator: \"\n",
    "    \"meta-llama/Llama-3.2-1B-wrong\"\n",
    "    \" (repository not found).\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Encode (embedding model)\n",
    "\n",
    "Encode text into embeddings. Note that this API is only available for [embedding models](openai_api_embeddings.html#openai-apis-embedding) and will raise an error for generation models.\n",
    "Therefore, we launch a new server to server an embedding model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "terminate_process(server_process)\n",
    "\n",
    "embedding_process = execute_shell_command(\n",
    "    \"\"\"\n",
    "python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct \\\n",
    "    --port 30020 --host 0.0.0.0 --is-embedding\n",
    "\"\"\"\n",
    ")\n",
    "\n",
    "wait_for_server(\"http://localhost:30020\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# successful encode for embedding model\n",
    "\n",
    "url = \"http://localhost:30020/encode\"\n",
    "data = {\"model\": \"Alibaba-NLP/gte-Qwen2-7B-instruct\", \"text\": \"Once upon a time\"}\n",
    "\n",
    "response = requests.post(url, json=data)\n",
    "response_json = response.json()\n",
    "print_highlight(f\"Text embedding (first 10): {response_json['embedding'][:10]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classify (reward model)\n",
    "\n",
    "SGLang Runtime also supports reward models. Here we use a reward model to classify the quality of pairwise generations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "terminate_process(embedding_process)\n",
    "\n",
    "# Note that SGLang now treats embedding models and reward models as the same type of models.\n",
    "# This will be updated in the future.\n",
    "\n",
    "reward_process = execute_shell_command(\n",
    "    \"\"\"\n",
    "python -m sglang.launch_server --model-path Skywork/Skywork-Reward-Llama-3.1-8B-v0.2 --port 30030 --host 0.0.0.0 --is-embedding\n",
    "\"\"\"\n",
    ")\n",
    "\n",
    "wait_for_server(\"http://localhost:30030\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "PROMPT = (\n",
    "    \"What is the range of the numeric output of a sigmoid node in a neural network?\"\n",
    ")\n",
    "\n",
    "RESPONSE1 = \"The output of a sigmoid node is bounded between -1 and 1.\"\n",
    "RESPONSE2 = \"The output of a sigmoid node is bounded between 0 and 1.\"\n",
    "\n",
    "CONVS = [\n",
    "    [{\"role\": \"user\", \"content\": PROMPT}, {\"role\": \"assistant\", \"content\": RESPONSE1}],\n",
    "    [{\"role\": \"user\", \"content\": PROMPT}, {\"role\": \"assistant\", \"content\": RESPONSE2}],\n",
    "]\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\")\n",
    "prompts = tokenizer.apply_chat_template(CONVS, tokenize=False)\n",
    "\n",
    "url = \"http://localhost:30030/classify\"\n",
    "data = {\"model\": \"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\", \"text\": prompts}\n",
    "\n",
    "responses = requests.post(url, json=data).json()\n",
    "for response in responses:\n",
    "    print_highlight(f\"reward: {response['embedding'][0]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "terminate_process(reward_process)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Skip Tokenizer and Detokenizer\n",
    "\n",
    "SGLang Runtime also supports skip tokenizer and detokenizer. This is useful in cases like integrating with RLHF workflow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer_free_server_process = execute_shell_command(\n",
    "    \"\"\"\n",
    "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010 --skip-tokenizer-init\n",
    "\"\"\"\n",
    ")\n",
    "\n",
    "wait_for_server(\"http://localhost:30010\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-3.2-1B-Instruct\")\n",
    "\n",
    "input_text = \"What is the capital of France?\"\n",
    "\n",
    "input_tokens = tokenizer.encode(input_text)\n",
    "print_highlight(f\"Input Text: {input_text}\")\n",
    "print_highlight(f\"Tokenized Input: {input_tokens}\")\n",
    "\n",
    "response = requests.post(\n",
    "    \"http://localhost:30010/generate\",\n",
    "    json={\n",
    "        \"input_ids\": input_tokens,\n",
    "        \"sampling_params\": {\n",
    "            \"temperature\": 0,\n",
    "            \"max_new_tokens\": 256,\n",
    "            \"stop_token_ids\": [tokenizer.eos_token_id],\n",
    "        },\n",
    "        \"stream\": False,\n",
    "    },\n",
    ")\n",
    "output = response.json()\n",
    "output_tokens = output[\"token_ids\"]\n",
    "\n",
    "output_text = tokenizer.decode(output_tokens, skip_special_tokens=False)\n",
    "print_highlight(f\"Tokenized Output: {output_tokens}\")\n",
    "print_highlight(f\"Decoded Output: {output_text}\")\n",
    "print_highlight(f\"Output Text: {output['meta_info']['finish_reason']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "terminate_process(tokenizer_free_server_process)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
